import os
import gzip
from collections import Counter
from scipy.stats import mannwhitneyu, linregress
from scipy.special import ndtr
from matplotlib.ticker import FormatStrFormatter
from pylab import *
from numpy import *


# Half-width in bp of the window centred on each enhancer / sampled position;
# every window therefore spans 2*halfwindow+1 bases.
halfwindow = 5000

# Seed the module-level ``random`` generator so the background sampling in
# calculate_background is reproducible between runs.
random.seed(12345)

def read_chromosome_sizes(assembly="hg38",
                          directory="/osc-fs_home/scratch/mdehoon/Data/Genomes"):
    """Read a ``<assembly>.chrom.sizes`` file into a dictionary.

    The file is looked up as ``<directory>/<assembly>/<assembly>.chrom.sizes``
    and must contain one whitespace-separated ``chromosome size`` pair per
    line; blank lines are ignored.

    Parameters
    ----------
    assembly : str
        Genome assembly name, used both as subdirectory and file prefix.
    directory : str
        Root directory holding one subdirectory per assembly.

    Returns
    -------
    dict
        Mapping of chromosome name to its length (int).
    """
    path = os.path.join(directory, assembly, "%s.chrom.sizes" % assembly)
    print("Reading chromosome sizes from", path)
    sizes = {}
    # The with-block closes the file even if a malformed line raises.
    with open(path) as handle:
        for line in handle:
            if not line.strip():
                continue
            chromosome, size = line.split()
            sizes[chromosome] = int(size)
    return sizes

def parse_chipseq(sizes, marker,
                  directory="/osc-fs_home/mdehoon/Data/CASPARs/ChIPSeq"):
    """Yield per-base ChIP-seq peak signal, one chromosome at a time.

    Reads the gzipped narrowPeak file for *marker* and, for every chromosome
    encountered, yields ``(chromosome, scores)`` where ``scores`` is a float
    array of length ``sizes[chromosome]``. Bases inside a peak hold that
    peak's column-7 value (``words[6]``); all other bases are 0. Where peaks
    overlap, the later line in the file wins.

    A chromosome is yielded when the file moves on to the next one, so the
    file is expected to list each chromosome's peaks contiguously.

    Parameters
    ----------
    sizes : dict
        Chromosome name -> length; every chromosome in the file must be a key.
    marker : str
        One of ``"H3K4me1"``, ``"H3K4me2"``, ``"H3K4me3"``.
    directory : str
        Directory containing the narrowPeak files.

    Raises
    ------
    ValueError
        If *marker* is unknown (raised on the first ``next()``).
    KeyError
        If the file contains a chromosome missing from *sizes*.
    """
    filenames = {
        "H3K4me1": "GSM5379664_WT_H3K4me1_peaks.narrowPeak.gz",
        "H3K4me2": "GSM5379668_WT_H3K4me2APR_peaks.narrowPeak.gz",
        "H3K4me3": "GSM5379671_WT_H3K4me3_peaks.narrowPeak.gz",
    }
    try:
        filename = filenames[marker]
    except KeyError:
        raise ValueError("Unknown marker %s" % marker) from None
    path = os.path.join(directory, filename)
    print("Reading", path)
    current = None
    scores = None
    # The with-block guarantees the gzip stream is closed once the generator
    # is exhausted (or closed early by the consumer).
    with gzip.open(path, "rt") as stream:
        for line in stream:
            words = line.split()
            # Skip blank lines and BED header/comment lines.
            if not words or words[0] in ("track", "browser") or words[0].startswith("#"):
                continue
            chromosome = words[0]
            start = int(words[1])
            end = int(words[2])
            score = float(words[6])
            if chromosome != current:
                if current is not None:
                    yield (current, scores)
                current = chromosome
                scores = zeros(sizes[chromosome])
            scores[start:end] = score
    if current is not None:
        yield (current, scores)

def read_significance(filename='enhancers.deseq.txt'):
    """Read DESeq results and derive signed significance per enhancer.

    The file must start with a header of ``enhancer`` followed by
    ``<tp>_basemean``, ``<tp>_log2fc``, ``<tp>_pvalue`` for ``tp`` in
    00hr, 01hr, 04hr, 12hr, 24hr, 96hr and ``all``; each data row must have
    exactly that many whitespace-separated fields. Only the ``all_*``
    columns are used.

    Returns
    -------
    ppvalues : dict
        Enhancer name -> ``-log10(all_pvalue) * sign(all_log2fc)``.
    logfc : dict
        Enhancer name -> ``all_log2fc``, only for enhancers with
        ``all_pvalue < 0.05``.

    Raises
    ------
    ValueError
        If the header or a row does not have the expected layout.
    """
    timepoints = ("00hr", "01hr", "04hr", "12hr", "24hr", "96hr", "all")
    fields = ("basemean", "log2fc", "pvalue")
    expected = ["enhancer"] + ["%s_%s" % (timepoint, field)
                               for timepoint in timepoints
                               for field in fields]
    log2fc_column = expected.index("all_log2fc")
    pvalue_column = expected.index("all_pvalue")
    print("Reading %s" % filename)
    ppvalues = {}
    logfc = {}
    with open(filename) as handle:
        header = next(handle, "").split()
        if header[:len(expected)] != expected:
            raise ValueError("Unexpected header in %s: %s" % (filename, header))
        for line in handle:
            words = line.split()
            if len(words) != len(expected):
                raise ValueError("Expected %d fields in %s, got %d: %r"
                                 % (len(expected), filename, len(words), line))
            name = words[0]
            log2fc = float(words[log2fc_column])
            pvalue = float(words[pvalue_column])
            # Signed -log10(p): the magnitude is significance, the sign is
            # the direction of the fold change.
            ppvalues[name] = -log10(pvalue) * sign(log2fc)
            if pvalue < 0.05:
                logfc[name] = log2fc
    return ppvalues, logfc

def read_expression(ppvalues, filename="enhancers.expression.txt"):
    """Sum HiSeq and CAGE read counts per enhancer.

    The file has a fixed 34-column header: ``enhancer``, 17 ``HiSeq_*``
    columns and 16 ``CAGE_*`` columns. Each count cell holds two
    comma-separated integers, which are added together. Rows whose name has
    the form ``StartSeq|...`` are skipped.

    Parameters
    ----------
    ppvalues : dict
        Output of ``read_significance``; every non-StartSeq enhancer in the
        file must be a key.
    filename : str
        Path of the expression table.

    Returns
    -------
    cage, hiseq : dict, dict
        Enhancer name -> total CAGE count, and enhancer name -> total HiSeq
        count (note the order: CAGE first).

    Raises
    ------
    ValueError
        On an unexpected header or row length, or an enhancer missing from
        *ppvalues*.
    """
    hiseq_columns = [
        'HiSeq_t00_r1', 'HiSeq_t00_r2', 'HiSeq_t00_r3',
        'HiSeq_t01_r1', 'HiSeq_t01_r2',
        'HiSeq_t04_r1', 'HiSeq_t04_r2', 'HiSeq_t04_r3',
        'HiSeq_t12_r1', 'HiSeq_t12_r2', 'HiSeq_t12_r3',
        'HiSeq_t24_r1', 'HiSeq_t24_r2', 'HiSeq_t24_r3',
        'HiSeq_t96_r1', 'HiSeq_t96_r2', 'HiSeq_t96_r3',
    ]
    cage_columns = [
        'CAGE_00_hr_A', 'CAGE_00_hr_C', 'CAGE_00_hr_G', 'CAGE_00_hr_H',
        'CAGE_01_hr_A', 'CAGE_01_hr_C', 'CAGE_01_hr_G',
        'CAGE_04_hr_C', 'CAGE_04_hr_E',
        'CAGE_12_hr_A', 'CAGE_12_hr_C',
        'CAGE_24_hr_C', 'CAGE_24_hr_E',
        'CAGE_96_hr_A', 'CAGE_96_hr_C', 'CAGE_96_hr_E',
    ]
    expected = ['enhancer'] + hiseq_columns + cage_columns
    first_cage = 1 + len(hiseq_columns)

    def total_count(cells):
        # Each cell is "a,b"; add both numbers for every cell. A plain loop is
        # used because the file's star imports may shadow the builtin sum.
        total = 0
        for cell in cells:
            count1, count2 = cell.split(",")
            total += int(count1) + int(count2)
        return total

    print("Reading", filename)
    hiseq = {}
    cage = {}
    with open(filename) as handle:
        header = next(handle, "").split()
        if header != expected:
            raise ValueError("Unexpected header in %s: %s" % (filename, header))
        for line in handle:
            words = line.split()
            if len(words) != len(expected):
                raise ValueError("Expected %d fields in %s, got %d: %r"
                                 % (len(expected), filename, len(words), line))
            name = words[0]
            # Unpacking enforces exactly one "|" in the enhancer name.
            dataset, _locus = name.split("|")
            if dataset == "StartSeq":
                continue
            if name not in ppvalues:
                raise ValueError("Enhancer %s missing from significance table" % name)
            hiseq[name] = total_count(words[1:first_cage])
            cage[name] = total_count(words[first_cage:])
    return cage, hiseq

def read_enhancer_loci(chromosome_sizes=None):
    """Split significant enhancers into "short" and "long" groups with loci.

    Enhancers with ``all_pvalue < 0.05`` (from ``read_significance``) are put
    in ``"short"`` when their log2 fold change is positive and in ``"long"``
    otherwise. The short group is sorted by total HiSeq count and the long
    group by total CAGE count, both descending.

    Parameters
    ----------
    chromosome_sizes : dict, optional
        Chromosome name -> length. Every chromosome gets an (initially empty)
        entry in the returned loci. Defaults to the module-level ``sizes``
        defined in the script section of this file, which the original
        version used implicitly.

    Returns
    -------
    loci : dict
        ``loci[key][chromosome][name] = (start, end)`` for key in
        ``("short", "long")``.
    names : dict
        ``names[key]`` is the sorted list of enhancer names for that group.
    """
    if chromosome_sizes is None:
        chromosome_sizes = sizes
    ppvalues, logfc = read_significance()
    cage, hiseq = read_expression(ppvalues)
    ppvalue_threshold = -log10(0.05)
    names = {'short': [], 'long': []}
    for name in ppvalues:
        log2fc = logfc.get(name)
        if log2fc is None:
            continue  # not significant
        assert abs(ppvalues[name]) > ppvalue_threshold
        key = "short" if log2fc > 0 else "long"
        names[key].append(name)
    names['short'].sort(key=hiseq.get, reverse=True)
    names['long'].sort(key=cage.get, reverse=True)
    assert len(names['long']) > len(names['short'])
    loci = {key: {chromosome: {} for chromosome in chromosome_sizes}
            for key in ("short", "long")}
    for key in ("short", "long"):
        for name in names[key]:
            chromosome, locus = _parse_enhancer_name(name)
            loci[key][chromosome][name] = locus
    return loci, names


def _parse_enhancer_name(name):
    """Split ``"source|chrom:start-end"`` into ``(chrom, (start, end))``."""
    source, locus = name.split("|")
    assert source in ("CAGE", "HiSeq", "Both", "FANTOM5")
    chromosome, start_end = locus.split(":")
    start, end = start_end.split("-")
    return chromosome, (int(start), int(end))

def calculate_background(sizes, marker, n=10000):
    """Sample the genome-wide background distribution of window scores.

    *n* positions on chr1-chr22, chrX and chrY are allocated to chromosomes
    by a multinomial draw proportional to chromosome length. They are then
    placed uniformly at random within each chromosome, at least
    ``halfwindow`` bp from either end. Each position is scored as the MEAN
    ChIP-seq signal over the ``2*halfwindow+1`` bp window centred on it.
    This is the same statistic ``calculate_profile`` reports per enhancer,
    so foreground and background scores are on the same scale.

    Parameters
    ----------
    sizes : dict
        Chromosome name -> length; must contain chr1-chr22, chrX, chrY.
    marker : str
        Histone mark passed to ``parse_chipseq``.
    n : int
        Number of background positions to sample.

    Returns
    -------
    numpy.ndarray
        One background score per sampled position (length *n*). Samples on
        chromosomes with no peak in the ChIP-seq file score 0, because such
        chromosomes are never yielded by ``parse_chipseq``.
    """
    chromosomes = ["chr" + str(i) for i in list(range(1, 23)) + ['X', 'Y']]
    p = array([sizes[chromosome] for chromosome in chromosomes], float)
    p /= sum(p)
    chromosome_counts = multinomial(n, p)
    background_scores = zeros(n)
    sampled = set()
    i = 0
    for chromosome, scores in parse_chipseq(sizes, marker):
        if chromosome not in chromosomes:
            continue
        sampled.add(chromosome)
        count = chromosome_counts[chromosomes.index(chromosome)]
        # randint's upper bound is exclusive, so every window
        # [centre-halfwindow, centre+halfwindow] lies inside the chromosome.
        low = halfwindow
        high = sizes[chromosome] - halfwindow - 1
        for centre in random.randint(low, high, count):
            # Mean (not sum) to match the per-enhancer scores.
            background_scores[i] = mean(scores[centre-halfwindow:centre+halfwindow+1])
            i += 1
    # Chromosomes without any peak keep their zero-initialised slots; count
    # them so the bookkeeping check below still holds.
    for chromosome, count in zip(chromosomes, chromosome_counts):
        if chromosome not in sampled:
            i += count
    assert i == n
    return background_scores

def calculate_profile(sizes, loci, marker):
    """Average ChIP-seq signal around enhancer centres, per category.

    For every enhancer in ``loci[category][chromosome]``, the centre is
    ``(start + end) // 2``. The signal in the ``2*halfwindow+1`` bp window
    around it is added to the category's profile. Windows that run past
    either end of the chromosome are truncated and aligned to their true
    offsets. Enhancers on chromosomes without any peak (which
    ``parse_chipseq`` never yields) are scored with zero signal rather than
    dropped.

    Returns
    -------
    profiles : dict
        category -> array of length ``2*halfwindow+1``: the mean signal at
        each offset, averaged over the enhancers whose window covers it.
    locus_scores : dict
        category -> {enhancer name: mean signal over its (possibly
        truncated) window}.
    """
    width = 2 * halfwindow + 1
    profiles = {category: zeros(width) for category in loci}
    count = {category: zeros(width, int) for category in loci}
    locus_scores = {category: {} for category in loci}

    def accumulate(category, name, values, offset):
        # ``values`` covers profile positions offset .. offset+len(values)-1.
        profiles[category][offset:offset + len(values)] += values
        count[category][offset:offset + len(values)] += 1
        locus_scores[category][name] = mean(values)

    seen = set()
    for chromosome, scores in parse_chipseq(sizes, marker):
        seen.add(chromosome)
        size = sizes[chromosome]
        for category in loci:
            for name, (start, end) in loci[category].get(chromosome, {}).items():
                low, high, offset = _clip_window((start + end) // 2, halfwindow, size)
                accumulate(category, name, scores[low:high], offset)
    # Chromosomes with enhancers but no peaks: their signal is zero everywhere.
    for category in loci:
        for chromosome, entries in loci[category].items():
            if chromosome in seen or not entries:
                continue
            size = sizes[chromosome]
            for name, (start, end) in entries.items():
                low, high, offset = _clip_window((start + end) // 2, halfwindow, size)
                accumulate(category, name, zeros(high - low), offset)
    for category in profiles:
        profiles[category] /= count[category]
    return profiles, locus_scores


def _clip_window(position, halfwidth, size):
    """Return ``(low, high, offset)`` for the window centred on *position*.

    ``scores[low:high]`` is the part of
    ``[position-halfwidth, position+halfwidth]`` inside ``[0, size)``, and
    *offset* is where that slice starts in a profile of length
    ``2*halfwidth+1``. Plain comparisons are used instead of min/max, which
    the file's star imports may shadow with numpy versions.
    """
    low = position - halfwidth
    if low < 0:
        low = 0
    high = position + halfwidth + 1
    if high > size:
        high = size
    offset = low - (position - halfwidth)
    return low, high, offset

sizes = read_chromosome_sizes(assembly="hg38")

loci, names = read_enhancer_loci()

categories = tuple(loci.keys())

markers = ("H3K4me1", "H3K4me3")

panels = []

x = arange(-halfwindow, halfwindow+1)
f = figure(figsize=(6,5))

# Full-figure axes with hidden spines, ticks and (white) tick labels; it only
# carries the x/y axis labels shared by the grid of panels below.
ax = f.add_subplot(111)
for side in ('top', 'bottom', 'left', 'right'):
    ax.spines[side].set_color('none')
ax.tick_params(labelcolor='w', top=False, bottom=False, left=False, right=False)
ax.set_xlabel("Position relative to enhancer center [bp]", fontsize=8)
ax.set_ylabel("Average MACS2 signal enrichment", fontsize=8, labelpad=20)

# One row of panels per marker, one column per enhancer category.
nrows = len(markers)
ncols = len(categories)

# scores[marker][category][enhancer name] -> mean signal in its window.
scores = {}
for i, marker in enumerate(markers):
    row = []
    panels.append(row)
    profiles, scores[marker] = calculate_profile(sizes, loci, marker)
    background_scores = calculate_background(sizes, marker)
    for j, category in enumerate(categories):
        foreground_scores = array(list(scores[marker][category].values()))
        # Request the two-sided test explicitly; SciPy < 1.7 otherwise used a
        # deprecated one-sided default.
        u, p = mannwhitneyu(foreground_scores, background_scores,
                            alternative="two-sided")
        print("%s, %s: Mann-Whitney p = %g" % (marker, category, p))
        panel = f.add_subplot(nrows, ncols, ncols*i+j+1)
        row.append(panel)
        if category.startswith("long"):
            color = "blue"
        elif category == "short":
            color = "red"
        else:
            raise Exception("Unknown category %s" % category)
        plot(x, profiles[category], color=color)
        yticks(fontsize=8)
        if j == 0:
            ylabel(marker)
        # Only the bottom row of panels shows x tick labels.
        if i == nrows - 1:
            xticks(fontsize=8)
        else:
            xticks([])
        if i == 0:
            n = sum([len(loci[category][chromosome]) for chromosome in sizes])
            if category == "short":
                protocol = "single-end"
            elif category == "long":
                protocol = "CAGE"
            title("Significantly enriched\nin %s capped RNA\n(%s) libraries\n($N = %d$)" % (category, protocol, n), fontsize=8)
    for panel in row:
        panel.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))

# For each category, regress the H3K4me3 window score on the H3K4me1 window
# score across its enhancers.
slopes = {}
slope_stderrs = {}
for category in ("short", "long"):
    x = array([scores["H3K4me1"][category][name] for name in names[category]])
    y = array([scores["H3K4me3"][category][name] for name in names[category]])
    result = linregress(x, y)
    slopes[category] = result.slope
    slope_stderrs[category] = result.stderr
slope_short = slopes["short"]
slope_short_stderr = slope_stderrs["short"]
slope_long = slopes["long"]
slope_long_stderr = slope_stderrs["long"]

# z-test on the slope difference, treating the two estimates as independent.
slope_stderr = hypot(slope_short_stderr, slope_long_stderr)
slope_difference = slope_long - slope_short
slope_ratio = slope_long / slope_short
slope_zscore = slope_difference / slope_stderr

# Two-sided p-value: use |z| so the result stays in [0, 1] whichever slope is
# larger (2*ndtr(-z) exceeds 1 when z < 0).
pvalue = 2*ndtr(-abs(slope_zscore))

print("H3K4me3/H3K4me1 enrichment: ratio = %f comparing long vs short, p-value = %g" % (slope_ratio, pvalue))

subplots_adjust(left=0.20,bottom=0.20,right=0.80)

filename = "figure_enhancer_chipseq.svg"
print("Saving figure as", filename)
savefig(filename)

filename = "figure_enhancer_chipseq.png"
print("Saving figure as", filename)
savefig(filename)
